/*!
 *  Copyright (c) 2020 by Contributors
 * \file dgl/aten/macro.h
 * \brief Common macros for aten package.
 */

#ifndef DGL_ATEN_MACRO_H_
#define DGL_ATEN_MACRO_H_

///////////////////////// Dispatchers //////////////////////////

/*
 * Dispatch according to device type (CPU only):
 *
 * ATEN_XPU_SWITCH(array->ctx.device_type, XPU, "OpName", {
 *   // Now XPU is a constexpr placeholder for array->ctx.device_type
 *   DeviceSpecificImplementation<XPU>(...);
 * });
 *
 * `op` is the operator name printed in the error message. Any device type
 * other than kDLCPU is reported via LOG(FATAL). The expansion is a single
 * do { ... } while (0) statement, so the caller supplies the semicolon.
 */
#define ATEN_XPU_SWITCH(val, XPU, op, ...) do {                 \
  if ((val) == kDLCPU) {                                        \
    constexpr auto XPU = kDLCPU;                                \
    {__VA_ARGS__}                                               \
  } else {                                                      \
    LOG(FATAL) << "Operator " << (op) << " does not support "   \
               << dgl::runtime::DeviceTypeCode2Str(val)         \
               << " device.";                                   \
  }                                                             \
} while (0)

/*
 * Dispatch according to device type, additionally allowing CUDA:
 *
 * XXX(minjie): temporary macro that allows CUDA operator
 *
 * ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "OpName", {
 *   // Now XPU is a constexpr placeholder for array->ctx.device_type
 *   DeviceSpecificImplementation<XPU>(...);
 * });
 *
 * When DGL_USE_CUDA is defined, kDLCPU and kDLGPU are accepted and any other
 * device type is reported via LOG(FATAL) together with the operator name
 * `op`. Without DGL_USE_CUDA this is a plain alias of ATEN_XPU_SWITCH.
 */
#ifdef DGL_USE_CUDA
#define ATEN_XPU_SWITCH_CUDA(val, XPU, op, ...) do {            \
  if ((val) == kDLCPU) {                                        \
    constexpr auto XPU = kDLCPU;                                \
    {__VA_ARGS__}                                               \
  } else if ((val) == kDLGPU) {                                 \
    constexpr auto XPU = kDLGPU;                                \
    {__VA_ARGS__}                                               \
  } else {                                                      \
    LOG(FATAL) << "Operator " << (op) << " does not support "   \
               << dgl::runtime::DeviceTypeCode2Str(val)         \
               << " device.";                                   \
  }                                                             \
} while (0)
#else  // DGL_USE_CUDA
#define ATEN_XPU_SWITCH_CUDA ATEN_XPU_SWITCH
#endif  // DGL_USE_CUDA

/*
 * Dispatch according to integral type (either int32 or int64):
 *
 * ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {
 *   // Now IdType is the type corresponding to data type in array.
 *   // For instance, one can do this for a CPU array:
 *   IdType *data = static_cast<IdType *>(array->data);
 * });
 *
 * A dtype whose code is not kDLInt fails the CHECK_EQ; an integer dtype whose
 * width is neither 32 nor 64 bits is reported via LOG(FATAL).
 */
#define ATEN_ID_TYPE_SWITCH(val, IdType, ...) do {            \
  CHECK_EQ((val).code, kDLInt) << "ID must be integer type";  \
  if ((val).bits == 32) {                                     \
    typedef int32_t IdType;                                   \
    {__VA_ARGS__}                                             \
  } else if ((val).bits == 64) {                              \
    typedef int64_t IdType;                                   \
    {__VA_ARGS__}                                             \
  } else {                                                    \
    LOG(FATAL) << "ID can only be int32 or int64";            \
  }                                                           \
} while (0)

/*
 * Dispatch according to an integer bit width (either 32 or 64):
 *
 * ATEN_ID_BITS_SWITCH(bits, IdType, {
 *   // Now IdType is int32_t when bits == 32 and int64_t when bits == 64.
 *   // For instance, one can do this for a CPU array:
 *   IdType *data = static_cast<IdType *>(array->data);
 * });
 *
 * Any other value of `bits` fails the CHECK with "bits must be 32 or 64".
 * Because that CHECK is fatal, the dispatch below only needs to distinguish
 * 32 from 64 (the former trailing LOG(FATAL) branch was unreachable).
 */
#define ATEN_ID_BITS_SWITCH(bits, IdType, ...)                          \
  do {                                                                  \
    CHECK((bits) == 32 || (bits) == 64) << "bits must be 32 or 64";     \
    if ((bits) == 32) {                                                 \
      typedef int32_t IdType;                                           \
      { __VA_ARGS__ }                                                   \
    } else {                                                            \
      typedef int64_t IdType;                                           \
      { __VA_ARGS__ }                                                   \
    }                                                                   \
  } while (0)

/*
 * Dispatch according to float type (either float32 or float64):
 *
 * ATEN_FLOAT_TYPE_SWITCH(array->dtype, FloatType, "array", {
 *   // Now FloatType is the type corresponding to data type in array.
 *   // For instance, one can do this for a CPU array:
 *   FloatType *data = static_cast<FloatType *>(array->data);
 * });
 *
 * `val_name` is only used in the error messages. A non-float dtype fails the
 * CHECK_EQ; a float dtype of any other width is reported via LOG(FATAL).
 */
#define ATEN_FLOAT_TYPE_SWITCH(val, FloatType, val_name, ...) do {  \
  CHECK_EQ((val).code, kDLFloat)                              \
    << (val_name) << " must be float type";                   \
  if ((val).bits == 32) {                                     \
    typedef float FloatType;                                  \
    {__VA_ARGS__}                                             \
  } else if ((val).bits == 64) {                              \
    typedef double FloatType;                                 \
    {__VA_ARGS__}                                             \
  } else {                                                    \
    LOG(FATAL) << (val_name) << " can only be float32 or float64";  \
  }                                                           \
} while (0)

/*
 * Dispatch according to the bit width of a float type (16, 32 or 64):
 *
 * ATEN_FLOAT_BITS_SWITCH(array->dtype, bits, "array", {
 *   // Now `bits` is a constexpr int equal to array->dtype.bits, so it can be
 *   // used as a template argument.
 *   DeviceSpecificImplementation<bits>(...);
 * });
 *
 * The placeholder parameter is deliberately not spelled `bits`: a parameter
 * with that name would also rewrite the member access `(val).bits`, so the
 * macro only compiled when callers happened to name the placeholder `bits`.
 * `val_name` is only used in the error messages. A non-float dtype fails the
 * CHECK_EQ; any other width is reported via LOG(FATAL).
 */
#define ATEN_FLOAT_BITS_SWITCH(val, Bits, val_name, ...) do {      \
  CHECK_EQ((val).code, kDLFloat)                                   \
    << (val_name) << " must be float type";                        \
  if ((val).bits == 16) {                                          \
    constexpr int Bits = 16;                                       \
    {__VA_ARGS__}                                                  \
  } else if ((val).bits == 32) {                                   \
    constexpr int Bits = 32;                                       \
    {__VA_ARGS__}                                                  \
  } else if ((val).bits == 64) {                                   \
    constexpr int Bits = 64;                                       \
    {__VA_ARGS__}                                                  \
  } else {                                                         \
    LOG(FATAL) << (val_name)                                       \
               << " can only be float16, float32 or float64";      \
  }                                                                \
} while (0)

/*
 * Dispatch according to data type (int32, int64, float32 or float64):
 *
 * ATEN_DTYPE_SWITCH(array->dtype, DType, "array", {
 *   // Now DType is the type corresponding to data type in array.
 *   // For instance, one can do this for a CPU array:
 *   DType *data = static_cast<DType *>(array->data);
 * });
 *
 * `val_name` is only used in the error message. Any other dtype (including
 * unsigned integers and other widths such as 16 bits) is reported via
 * LOG(FATAL).
 */
#define ATEN_DTYPE_SWITCH(val, DType, val_name, ...) do {     \
  if ((val).code == kDLInt && (val).bits == 32) {             \
    typedef int32_t DType;                                    \
    {__VA_ARGS__}                                             \
  } else if ((val).code == kDLInt && (val).bits == 64) {      \
    typedef int64_t DType;                                    \
    {__VA_ARGS__}                                             \
  } else if ((val).code == kDLFloat && (val).bits == 32) {    \
    typedef float DType;                                      \
    {__VA_ARGS__}                                             \
  } else if ((val).code == kDLFloat && (val).bits == 64) {    \
    typedef double DType;                                     \
    {__VA_ARGS__}                                             \
  } else {                                                    \
    LOG(FATAL) << (val_name) << " can only be int32, int64, float32 or float64"; \
  }                                                           \
} while (0)

/*
 * Dispatch on the integer dtype of CSR matrix data:
 *
 * ATEN_CSR_DTYPE_SWITCH(array->dtype, DType, {
 *   // Now DType is int32_t or int64_t, matching array->dtype.
 * });
 *
 * Accepts exactly the dtypes ATEN_ID_TYPE_SWITCH accepts (code kDLInt with a
 * width of 32 or 64 bits). Every other dtype, whether a non-integer code or
 * an unsupported width, is reported via LOG(FATAL) with a CSR-specific
 * message.
 */
#define ATEN_CSR_DTYPE_SWITCH(val, DType, ...) do {                 \
  if ((val).code != kDLInt) {                                      \
    LOG(FATAL) << "CSR matrix data can only be int32 or int64";    \
  } else if ((val).bits == 32) {                                   \
    typedef int32_t DType;                                         \
    { __VA_ARGS__ }                                                \
  } else if ((val).bits == 64) {                                   \
    typedef int64_t DType;                                         \
    { __VA_ARGS__ }                                                \
  } else {                                                         \
    LOG(FATAL) << "CSR matrix data can only be int32 or int64";    \
  }                                                                \
} while (0)

/*
 * Dispatch according to the device context and index type of a CSR matrix,
 * both read from its `indptr` array:
 *
 * ATEN_CSR_SWITCH(csr, XPU, IdType, "OpName", {
 *   // XPU is csr.indptr's device type and IdType its integer index type.
 * });
 *
 * The expansion deliberately has no trailing semicolon: it is the single
 * statement produced by ATEN_XPU_SWITCH, so the caller's `;` terminates it
 * and the macro stays usable inside an unbraced if/else.
 */
#define ATEN_CSR_SWITCH(csr, XPU, IdType, op, ...)            \
  ATEN_XPU_SWITCH((csr).indptr->ctx.device_type, XPU, op, {   \
    ATEN_ID_TYPE_SWITCH((csr).indptr->dtype, IdType, {        \
      {__VA_ARGS__}                                           \
    });                                                       \
  })

/*
 * Dispatch according to the device context and index type of a COO matrix,
 * both read from its `row` array:
 *
 * ATEN_COO_SWITCH(coo, XPU, IdType, "OpName", {
 *   // XPU is coo.row's device type and IdType its integer index type.
 * });
 *
 * The expansion deliberately has no trailing semicolon: it is the single
 * statement produced by ATEN_XPU_SWITCH, so the caller's `;` terminates it
 * and the macro stays usable inside an unbraced if/else.
 */
#define ATEN_COO_SWITCH(coo, XPU, IdType, op, ...)          \
  ATEN_XPU_SWITCH((coo).row->ctx.device_type, XPU, op, {    \
    ATEN_ID_TYPE_SWITCH((coo).row->dtype, IdType, {         \
      {__VA_ARGS__}                                         \
    });                                                     \
  })

/*
 * CUDA-enabled counterparts of ATEN_CSR_SWITCH / ATEN_COO_SWITCH. With
 * DGL_USE_CUDA they dispatch the device through ATEN_XPU_SWITCH_CUDA (which
 * also accepts kDLGPU); without it they are plain aliases of the CPU-only
 * macros. Like those, the expansion has no trailing semicolon, so the
 * caller's `;` terminates the single statement.
 */
#ifdef DGL_USE_CUDA
// Device and index-type dispatch on csr.indptr, allowing CUDA.
#define ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, op, ...)            \
  ATEN_XPU_SWITCH_CUDA((csr).indptr->ctx.device_type, XPU, op, {   \
    ATEN_ID_TYPE_SWITCH((csr).indptr->dtype, IdType, {             \
      {__VA_ARGS__}                                                \
    });                                                            \
  })

// Device and index-type dispatch on coo.row, allowing CUDA.
#define ATEN_COO_SWITCH_CUDA(coo, XPU, IdType, op, ...)            \
  ATEN_XPU_SWITCH_CUDA((coo).row->ctx.device_type, XPU, op, {      \
    ATEN_ID_TYPE_SWITCH((coo).row->dtype, IdType, {                \
      {__VA_ARGS__}                                                \
    });                                                            \
  })
#else  // DGL_USE_CUDA
#define ATEN_CSR_SWITCH_CUDA ATEN_CSR_SWITCH
#define ATEN_COO_SWITCH_CUDA ATEN_COO_SWITCH
#endif  // DGL_USE_CUDA

///////////////////////// Array checks //////////////////////////

// Private helper for the IS_* predicates: true iff array pointer `a` has
// dtype code `type_code` and width `nbits`. `type_code` and `nbits` are left
// unparenthesized on purpose so the fully expanded text is token-for-token
// the same as writing the comparison out by hand (this matters if a CHECK
// stringifies its condition). `a` may be evaluated twice.
#define ATEN_IS_DTYPE_(a, type_code, nbits) \
  ((a)->dtype.code == type_code && (a)->dtype.bits == nbits)

#define IS_INT32(a)   ATEN_IS_DTYPE_(a, kDLInt, 32)
#define IS_INT64(a)   ATEN_IS_DTYPE_(a, kDLInt, 64)
#define IS_FLOAT32(a) ATEN_IS_DTYPE_(a, kDLFloat, 32)
#define IS_FLOAT64(a) ATEN_IS_DTYPE_(a, kDLFloat, 64)

// CHECK that `condition` holds; otherwise the failure message reads
// "Expecting <what> of <name> to be <expected>".
#define CHECK_IF(condition, what, name, expected) \
  CHECK(condition) << "Expecting " << (what) << " of " << (name) << " to be " << (expected)

// Dtype checks on an array pointer `arr`; `arr_name` labels it in the message.
#define CHECK_INT32(arr, arr_name) \
  CHECK_IF(IS_INT32(arr), "dtype", arr_name, "int32")
#define CHECK_INT64(arr, arr_name) \
  CHECK_IF(IS_INT64(arr), "dtype", arr_name, "int64")
#define CHECK_INT(arr, arr_name) \
  CHECK_IF(IS_INT32(arr) || IS_INT64(arr), "dtype", arr_name, "int32 or int64")
#define CHECK_FLOAT32(arr, arr_name) \
  CHECK_IF(IS_FLOAT32(arr), "dtype", arr_name, "float32")
#define CHECK_FLOAT64(arr, arr_name) \
  CHECK_IF(IS_FLOAT64(arr), "dtype", arr_name, "float64")
#define CHECK_FLOAT(arr, arr_name) \
  CHECK_IF(IS_FLOAT32(arr) || IS_FLOAT64(arr), "dtype", arr_name, "float32 or float64")

// Rank check on `arr`. The parameter must not be spelled `ndim`, or it would
// also rewrite the member access `(arr)->ndim`.
#define CHECK_NDIM(arr, expected_ndim, arr_name) \
  CHECK_IF((arr)->ndim == (expected_ndim), "ndim", arr_name, expected_ndim)

// CHECK that array pointers REF and ARG have equal dtypes. The message names
// both argument expressions (via stringification) and prints both dtypes.
// NOTE: the expansion itself already ends with ';'.
#define CHECK_SAME_DTYPE(REF, ARG)                                          \
  CHECK((REF)->dtype == (ARG)->dtype)                                       \
    << "Expected " << #ARG << " to be the same type as " << #REF            \
    << "(" << (REF)->dtype << ")" << ". But got " << (ARG)->dtype << ".";

// CHECK that array pointers REF and ARG live in equal device contexts.
// NOTE: the expansion itself already ends with ';'.
#define CHECK_SAME_CONTEXT(REF, ARG)                                        \
  CHECK((REF)->ctx == (ARG)->ctx)                                           \
    << "Expected " << #ARG << " to have the same device context as " << #REF \
    << "(" << (REF)->ctx << ")" << ". But got " << (ARG)->ctx << ".";

/*
 * CHECK that `val` fits in a signed 32-bit integer when `dtype` is 32 bits wide.
 *
 * The check only runs when `val` itself is 8 bytes wide (e.g. an int64_t)
 * and `(dtype).bits` is 32, and only the upper bound 0x7FFFFFFF is checked.
 * The expansion is a single do { ... } while (0) statement with no trailing
 * semicolon: the caller's `;` terminates it, so the macro is safe inside an
 * unbraced if/else.
 */
#define CHECK_NO_OVERFLOW(dtype, val)                                                   \
  do {                                                                                  \
    if (sizeof(val) == 8 && (dtype).bits == 32) {                                       \
      CHECK_LE((val), 0x7FFFFFFFL) << "int32 overflow for argument " << (#val) << ".";  \
    }                                                                                   \
  } while (0)

// CHECK that array pointer ARR is one-dimensional and satisfies IS_INT32 or
// IS_INT64; the message names the argument expression via stringification.
// NOTE: the expansion itself already ends with ';'.
#define CHECK_IS_ID_ARRAY(ARR)                                              \
  CHECK((ARR)->ndim == 1 && (IS_INT32(ARR) || IS_INT64(ARR)))               \
    << "Expected argument " << #ARR << " to be an 1D integer array.";

#endif  // DGL_ATEN_MACRO_H_
